# -*-Python-*-
# MoE model FLOP-matched to the AT5 baseline

import mesh_tensorflow.layers
import mesh_tensorflow.transformer.moe


# MoE parameters
# These bindings carry no scope prefix, so they are not restricted to the
# encoder or the decoder stack.
MoE1D.num_experts = 32
MoE1D.activation = ["relu"]
MoE1D.hidden_size = 1536  # 1/2 d_ff in order to FLOP-match
# %dropout_rate is a macro reference; it is not defined in this file and is
# expected to come from another config that is loaded alongside this one.
MoE1D.dropout_rate = %dropout_rate

# Construct new layers
# `num_layers` is a Gin macro; it is not referenced anywhere in this file.
# NOTE(review): presumably consumed as %num_layers by a base config (e.g. as
# the repeat count for the layer lists below) -- confirm against that config.
num_layers = 3

# Encoder layer list: three (SelfAttention, DenseReluDense) pairs followed by
# one (SelfAttention, MoE1D) pair, i.e. the last of the four feed-forward
# sublayers listed here is a mixture-of-experts layer.
encoder/transformer.make_layer_stack.layers = [
  @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
  @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
  @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
  @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
  @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
  @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
  @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
  @mesh_tensorflow.transformer.moe.MoE1D,
]

# Decoder layer list: three (SelfAttention, EncDecAttention, DenseReluDense)
# triples followed by one (SelfAttention, EncDecAttention, MoE1D) triple, i.e.
# the last of the four feed-forward sublayers listed here is a
# mixture-of-experts layer, mirroring the encoder list.
decoder/transformer.make_layer_stack.layers = [
  @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
  @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
  @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
  @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
  @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
  @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
  @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
  @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
  @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
  @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
  @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
  @mesh_tensorflow.transformer.moe.MoE1D,
]

VarianceScalingInitializer.scale = 0.1

# Expert capacity factors for the MoE layers, bound separately per stack.
# The scope prefix must match the scope the layer stack is configured under
# (`encoder/` and `decoder/`); a binding under a misspelled scope is accepted
# by Gin but never applies, leaving that stack on MoE1D's default values.
encoder/MoE1D.capacity_factor_train = 2.0
decoder/MoE1D.capacity_factor_train = 2.0
encoder/MoE1D.capacity_factor_eval = 2.0
decoder/MoE1D.capacity_factor_eval = 2.0

# Set the activation dtype returned by utils.get_variable_dtype to bfloat16.
utils.get_variable_dtype.activation_dtype = "bfloat16"
